Add Ascend NPU as a backend #1826
base: main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1826
Note: Links to docs will display an error until the docs builds have been completed.
❌ 4 New Failures, 2 Cancelled Jobs as of commit 99a6dd8 with merge base e99b890.
NEW FAILURES - The following jobs have failed.
CANCELLED JOBS - The following jobs were cancelled. Please retry.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from ca78eca to b6332dd.
Hi @ebsmothers, @RdoubleA: I hope you’re doing well! Could you please help me review my code? I would really appreciate it if you could take a look and share any feedback or suggestions. Thank you so much in advance for your time and support! 😊 Best regards
Hi @noemotiovon thanks for the PR! And apologies for the delay in getting to the review here. A couple other questions I have that don't really fit neatly anywhere inline:
- Do we expect compile to work? If so, we should test that. If not, we could raise an error
- Do we expect quant-related APIs (e.g. QLoRA or QAT) from torchao to work? Same as point 1: if so we should test or possibly raise an error
- PyTorch has now released 2.5 as stable. In general we do not claim to support anything but the latest stable release of PyTorch -- do you know the contract on torch_npu releases here?
Distributed training seems to have problems, e.g. qat_distributed. @noemotiovon
I would be very happy to! I will contact you via email.
@noemotiovon through my 126 email, thanks. Looking forward to your email.
Force-pushed from b6332dd to 1aad0d2.
Basic Usage Test (collapsed test details omitted)
Hi @ebsmothers, thank you very much for reviewing my code!
Force-pushed from ecbacac to 5d6cf85.
Hi @ebsmothers, could you please take a moment to review the code?
Force-pushed from 5d6cf85 to 03f782c.
Hi @noemotiovon, sorry for the delay! I will take a look tomorrow if that's alright. Until then I'll tag @RdoubleA and @joecummings in case either of them gets a minute to take a look.
Force-pushed from 03f782c to 4478809.
Hi @ebsmothers, when you have a moment, could you take a quick look at the recent changes I made? Your feedback would be greatly appreciated. Thank you!
```diff
@@ -430,7 +435,7 @@ def _setup_model(

         log.info(f"Model is initialized with precision {self._dtype}.")

-        if self._device.type == "cuda":
+        if self._device.type in DeviceSupport.get_cuda_like_device_types():
```
Maybe it's just me, but I find this a bit confusing... can we just scrap this method and use `if self._device.type != "cpu"` in these places instead? I know it may not be as general, but I think this is extra indirection that we don't really need.
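A minimal sketch of that suggestion, with an illustrative helper name that is not from the PR:

```python
import torch


def is_accelerator(device: torch.device) -> bool:
    # The suggested check: anything that is not CPU (CUDA, NPU, MPS, ...)
    # takes the accelerator path, with no separate "cuda-like" helper.
    return device.type != "cpu"


print(is_accelerator(torch.device("cpu")))   # False
print(is_accelerator(torch.device("cuda")))  # True (constructing the device needs no GPU)
```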
I think you’re right! This encapsulation seems redundant at the moment, as cuda-like devices currently appear to be simply non-CPU devices.
torchtune/training/memory.py (outdated)
```diff
@@ -45,11 +45,11 @@ def set_activation_checkpointing(

 def cleanup_before_training() -> None:
     """
-    Call gc collect, empty CUDA cache, and reset peak memory stats.
+    Call gc collect, empty CUDA-like cache, and reset peak memory stats.
```
nit: I think "device" should be sufficient here
```diff
-    Call gc collect, empty CUDA-like cache, and reset peak memory stats.
+    Call gc collect, empty device cache, and reset peak memory stats.
```
Thank you for the suggestion! I’ll make the descriptions more appropriate.
```diff
@@ -50,6 +52,7 @@ def verify_bf16_support() -> bool:
     - CUDA compute capability >= 8
     - NCCL is available and version >= 2.10
     - MPS is available and torch was built with MPS
+    - NPU is available and supports bf16
```
nit: this is a bit redundant. Do we know the exact requirements for bf16 support on NPUs?
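For reference, one possible shape of an NPU-specific bf16 check, assuming torch_npu mirrors the CUDA-style capability API (`torch.npu.is_available()` / `torch.npu.is_bf16_supported()`); the exact hardware requirement is not confirmed here:

```python
import torch


def npu_supports_bf16() -> bool:
    # Sketch only: assumes torch_npu is installed and exposes a CUDA-like API.
    try:
        import torch_npu  # noqa: F401  # registers the "npu" device with torch
    except ImportError:
        return False
    return torch.npu.is_available() and torch.npu.is_bf16_supported()
```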
torchtune/utils/_device.py (outdated)
```python
        return False


logger = get_logger("DEBUG")
```
Can we put this at the top of the file (i.e. just after imports but before function definitions)? That's where I'd expect to see it
torchtune/utils/_device.py (outdated)
```python
device_type = get_device_support().device_type
device_name = get_device_support().device_name
```
nit: can just call `get_device_support()` once
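In code, the nit amounts to caching the result; this snippet uses a hypothetical stand-in for the PR's `get_device_support` so it is self-contained:

```python
from dataclasses import dataclass


@dataclass
class DeviceSupport:  # hypothetical stand-in for the PR's type
    device_type: str
    device_name: str


def get_device_support() -> DeviceSupport:  # hypothetical stand-in
    return DeviceSupport(device_type="cpu", device_name="CPU")


# Call once and reuse, instead of calling get_device_support() twice:
support = get_device_support()
device_type = support.device_type
device_name = support.device_name
```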
torchtune/utils/_device.py (outdated)
```diff
     # Ensure device index matches assigned index when distributed training
     if device.index != local_rank:
         raise RuntimeError(
             f"You can't specify a device index when using distributed training. \
-            Device specified is {device} but was assigned cuda:{local_rank}"
+            Device specified is {device} but was assigned cuda device:{local_rank}"
```
I think this log message was not super clear before but this is also unclear. Maybe something like
```diff
-            Device specified is {device} but was assigned cuda device:{local_rank}"
+            Device specified is {device} but local rank is {local_rank}"
```
(assuming that NPU devices also contain rank in their string representation?)
Btw a higher-level question here: it was mentioned in a previous comment that there were issues running some of the distributed training scripts on NPU. Did that all get sorted out?
Yes, NPU devices also contain the rank in their string representation.
There are still some issues with adapting NPU for distributed scripts. I’m working on it and will include the updates in a future PR. Thanks for the reminder!
In that case we probably don't even need to compare device index against device count for NPU, right? (At least not until those changes land.) Though I think it's fine to leave this as is since handling NPU separately here may be messier
torchtune/utils/_device.py (outdated)
1. `device_type` (str): The type of device (e.g., "cpu", "cuda", "npu").
2. `device_name` (str): A user-friendly name for the device (e.g., "CPU", "GPU", "NPU").
3. `communication_backend` (str): Specifies the backend used for communication on this device (e.g., "gloo", "nccl", "hccl").
4. `cuda_like` (bool): Indicates whether the device is CUDA-like or not.
Discussed in another comment but I'm not sure we 100% need this field. At the very least I find its naming a bit unclear
Thank you, I’ll remove this seemingly redundant attribute!
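For illustration, with `cuda_like` dropped, the enum could look roughly like this (a sketch based on the docstring above; the PR's actual definition may differ):

```python
from enum import Enum


class DeviceSupport(Enum):
    # (device_type, device_name, communication_backend)
    CPU = ("cpu", "CPU", "gloo")
    CUDA = ("cuda", "GPU", "nccl")
    NPU = ("npu", "NPU", "hccl")

    def __init__(self, device_type: str, device_name: str, communication_backend: str):
        self.device_type = device_type
        self.device_name = device_name
        self.communication_backend = communication_backend


print(DeviceSupport.NPU.communication_backend)  # hccl
```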
Thanks @noemotiovon for your patience! I left a handful more comments, please let me know if anything is unclear
Thank you for your review! Your feedback is very clear, and I will make the necessary code changes as soon as possible based on your suggestions.
Hi @ebsmothers, I’ve made the code changes based on your suggestions; could you please review it again? Additionally:
Best regards
```python
    return DeviceSupport.from_type(device_type)


def get_torch_device() -> any:
```
I think `get_torch_device` is not descriptive enough and is too similar to our existing `get_device` method. Given the name I would expect this to return a torch.device, but it's really returning a module/namespace, right? In that case maybe we could call it `get_torch_device_namespace` or something?
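A sketch of what the renamed helper could look like, built around the getattr fallback quoted below (the signature and logging here are illustrative, not the PR's exact API):

```python
import logging

import torch

logger = logging.getLogger(__name__)


def get_torch_device_namespace(device_type: str):
    """Return the torch.<device_type> namespace (e.g. torch.cuda, torch.npu),
    falling back to torch.cuda if no such namespace exists."""
    try:
        return getattr(torch, device_type)
    except AttributeError:
        logger.warning(
            f"Device namespace '{device_type}' not found in torch, falling back to torch.cuda."
        )
        return torch.cuda


# Usage: get_torch_device_namespace("cuda").empty_cache()
```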
"""Return the corresponding torch attribute based on the device type string. | ||
|
||
Returns: | ||
module: The corresponding torch module, or torch.cuda if not found. |
Kinda related to my above comment (specifically it's why I suggested namespace in the function name instead of module).. I think module is a pretty overloaded term in PyTorch, and when I see this I think of nn.Module. Even though you're using it correctly, maybe we can say something like this instead to mitigate any potential confusion?
```diff
-        module: The corresponding torch module, or torch.cuda if not found.
+        module: The corresponding torch device namespace, or torch.cuda if not found.
```
```python
        return getattr(torch, device_type)
    except AttributeError:
        logger.warning(
            f"Device Module '{device_type}' not found in torch, try to load torch.cuda."
```
Similar comment here
f"Device Module '{device_type}' not found in torch, try to load torch.cuda." | |
f"Device namespace '{device_type}' not found in torch, try to load torch.cuda." |
Thanks @noemotiovon for the updates! I left a couple more comments but I think this is pretty close now. It looks like a unit test is failing in CI though, can you take a look? Happy to provide any debugging pointers if you need
What does this PR do?
Overview
🚀 This PR enables users of torchtune to leverage the Ascend NPU for better inference performance when a GPU is not available. This PR primarily addresses the initial refactoring of device-independent code. In upcoming changes, we’ll focus on further adjustments, using NPU as an example to refine each recipe and complete the remaining device-independent modifications. For now, this PR only touches the lora_finetune_single_device and full_finetune_single_device recipes.
For more details, see: [#1797].
Environment
Note:
- To properly install CANN, see [here] for more details.
- The version of torch-npu should match that of torch, see [here] for more details.
- In addition, torch_npu has a pre-release version, 2.4.0 RC1, which is also the basis for this test. For more information, please visit [here].

Examples
To start with, the library torch_npu should be correctly installed and imported. Part of the code is shown below (torchtune/utils/_device_support.py):
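As an illustrative sketch only (not the PR's actual _device_support.py code), the basic pattern is to import torch_npu when it is available so that torch registers the "npu" device:

```python
import torch

try:
    import torch_npu  # noqa: F401  # importing torch_npu registers the "npu" device
    is_npu_available = torch.npu.is_available()
except ImportError:
    is_npu_available = False

device = torch.device("npu" if is_npu_available else "cpu")
print(f"Running on: {device}")
```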
Plus, there are a few other places in the code that might need adjusting, but not too many.
Feel free to leave comments to guide me in further improvements 😊.